Color Palette Extraction from Image

7. 3. 2022

One of my favorite things on Facebook (I guess they are bigger on Instagram) is colorpalette.cinema, a channel dedicated to showing the color composition of gorgeous stills from old and new movies alike, presented as a palette. You can see an example below.
Automating the palette creation seems easy (colorpalette.cinema (c.c) almost certainly automates it too), so I decided to try to create a script that replicates their images: it takes any image as input, extracts the main distinct colors, and sorts them in a certain order. For this, I treat the pixels of the image as points in the 3D space of their RGB values and use clustering to retrieve 10 distinct colors (the standard at c.c), which are then placed in a spectrum.
I use two clustering methods to compare the results, and sometimes compare against the palette from c.c. It turns out that color is difficult, and computationally sorting colors in an intuitive way is especially hard. So for the final palettes, my implementation sometimes requires sorting the colors manually into the most visually appealing order.

[Example palette image from colorpalette.cinema]

In [1]:
#Importing packages
from itertools import combinations

import numpy as np
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D
from PIL import Image

from sklearn.cluster import KMeans, AgglomerativeClustering

Data from image

The Pillow package is great at handling images: I open the image, read the RGB values of each pixel, and place them in an array. I can then plot the data points in a 3D coordinate system.

In [2]:
def image_to_numpy(filename):
    image = Image.open(filename)
    data = np.asarray(image)
    data = data[:,:,0:3] #remove 4th col if rgba
    
    y_size = data.shape[0]
    x_size = data.shape[1]
    
    #downscale so the plot fits the notebook
    scaling_param = int((x_size/800 + y_size/600) / 2) + 1
    height_list = [i for i in range(0, y_size, scaling_param)] 
    width_list = [i for i in range(0, x_size, scaling_param)]

    data = data[height_list, :, :]
    data = data[:, width_list, :]
    data2d = np.reshape(data,(len(height_list)*len(width_list), 3)) 
    
    return data, data2d
In [3]:
def scatterplot_rgb(data2d, labels=None):
    fig = plt.figure(figsize=(12,9))
    ax = fig.add_subplot(projection='3d') #fig.gca(projection='3d') was removed in matplotlib 3.6

    x = data2d[:,0]
    y = data2d[:,1]
    z = data2d[:,2]
    
    if labels is not None:
        ax.scatter(x, y, z, c=labels.astype(float), edgecolor="k")
    else:
        ax.scatter(x, y, z, edgecolor="k")        
    
    ax.set_xlabel('Red Value')
    ax.set_ylabel('Green Value')
    ax.set_zlabel('Blue Value')

    plt.show()

Reducing image size

I found that k-means clustering can handle sample sizes in the millions and cluster them in a reasonable time (a couple of minutes). Agglomerative clustering, however, needs so much RAM that only computing clusters have it, so I had to reduce the sample size. In this approach I reduce images to around 200x100 pixels by keeping every n-th pixel and discarding the rest. Arguably I lose some information, but images at this reduced resolution still carry enough of the main colors that I want to visualize.
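As a sanity check on the downscaling step, the scaling parameter can be computed by hand; a minimal sketch (the helper below just restates the formula from reduce_image_res):

```python
# Restating the downscaling formula from reduce_image_res:
# average how far the image exceeds 200px width and 100px height,
# then keep every scaling_param-th pixel in both dimensions.
def scaling_param(x_size, y_size):
    return int((x_size / 200 + y_size / 100) / 2)

# The 1024x544 Blade Runner still used later keeps every 5th pixel
print(scaling_param(1024, 544))  # -> 5
```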

In [4]:
def reduce_image_res(filename):
    image = Image.open(filename)
    data = np.asarray(image)
    data = data[:,:,0:3] #remove 4th col if rgba
    
    y_size = data.shape[0]
    x_size = data.shape[1]
    
    #reduce image to ~200x100 pixel size (step of at least 1 for small images)
    scaling_param = max(int((x_size/200 + y_size/100) / 2), 1)
    
    height_list = [i for i in range(0, y_size, scaling_param)] 
    width_list = [i for i in range(0, x_size, scaling_param)]
    
    data = data[height_list, :, :] #keep every n-th row
    data = data[:, width_list, :]  #keep every n-th column
    data2d = np.reshape(data,(len(height_list)*len(width_list), 3))
    
    return data, data2d, scaling_param
In [5]:
filename = 'elden.png'
data, data2d, scaling_param = reduce_image_res(filename)
reduced_image = Image.fromarray(data)
display(reduced_image)

Clustering methods

K-means clustering is the default method for almost everything and works well for me too: it handles large sample sizes and is fast. I tried a couple of others, and agglomerative clustering worked best, especially with the complete linkage option. Linkage determines the distance measure by which clusters are joined together until the specified number n of clusters remains. Complete linkage measures the distance between two clusters as the largest pairwise distance between their points, which keeps clusters compact and well separated, which is perfect for this project.
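The linkage choice can be illustrated on toy data; a minimal sketch with made-up point sets, comparing the single- and complete-linkage distances between two clusters:

```python
import numpy as np

# Two hypothetical clusters of RGB-like points; one outlier in b
a = np.array([[10, 10, 10], [20, 20, 20]])
b = np.array([[25, 25, 25], [200, 200, 200]])

# All pairwise Euclidean distances between the two clusters
dists = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=2)

single = dists.min()    # single linkage: distance of the closest pair
complete = dists.max()  # complete linkage: distance of the farthest pair

print(single, complete)
```

Single linkage would consider these clusters nearly touching because of one close pair, while complete linkage keeps them far apart; that is why single linkage tends to chain fairly homogeneous pixel data into one blob.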

In [6]:
def n_kmeans(data2d, n):
    
    est = KMeans(n_clusters=n)
    est.fit(data2d)
    labels = est.labels_
    centers = est.cluster_centers_
        
    return centers, labels
In [7]:
def agglomerative(data2d, n, linkage):
    """Only to use with small images, 200x100 pixels max!"""
    
    est = AgglomerativeClustering(n_clusters=n, linkage=linkage)
    est.fit(data2d)
    labels = est.labels_
    
    #AgglomerativeClustering has no cluster_centers_, so compute each
    #centroid as the mean of the points assigned to its cluster
    centers = np.asarray([data2d[labels == cluster].mean(axis=0)
                          for cluster in range(n)])
    
    return centers, labels

Get image from cluster centroids

I want both automatic sorting methods and a manual one. dark_to_light adds up the three RGB values of each center and sorts by that sum; small sums naturally correspond to dark colors and large sums to light ones.
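A minimal sketch of that sum-based sort on made-up centers; note the plain RGB sum is only a rough brightness proxy, and a perceptual luma weighting (roughly 0.299R + 0.587G + 0.114B) would order some hues differently:

```python
import numpy as np

# Hypothetical cluster centers: mid gray, near-black, near-white
centers = np.array([[120, 120, 120], [10, 20, 30], [240, 250, 230]])

order = centers.sum(axis=1).argsort()  # ascending sum: dark to light
centers_sorted = centers[order]

print(centers_sorted)  # near-black row first, near-white row last
```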

In [8]:
def sort_colors(centers, method):
    """
    
    Either by finding along which color is the largest separation and sorting by that color, or by
    sorting from darkest to lightest colors.
    """
    
    if method == None:
        centers_sorted = centers
    elif method == "red":
        centers_sorted = centers[centers[:, 0].argsort()]
    elif method == "green":
        centers_sorted = centers[centers[:, 1].argsort()]
    elif method == "blue":
        centers_sorted = centers[centers[:, 2].argsort()]
    
    elif method == "along_color":
        
        dist_sums = []
        for color in range(3):
            color_col = centers[:, color]
            pairs = list(combinations(color_col, 2)) #get every pair
            
            dist_sum = 0
            for a, b in pairs:
                dist_sum += np.abs(a - b)
            dist_sums.append(dist_sum)
        
        #getting sorting order
        sortby = dist_sums.index(max(dist_sums))
        
        centers_sorted = centers[centers[:, sortby].argsort()]
                
    elif method == "dark_to_light":
        
        color_sorter = []
        for rgb in centers:
            color_sorter.append(sum(rgb)) #just summing of r/g/b values, smaller values are darker

        color_sorter = np.array(color_sorter).reshape((len(color_sorter), 1))
        centers = np.append(centers, color_sorter, axis=1)

        centers_sorted = centers[centers[:, 3].argsort()] #sorting in ascending order
        centers_sorted = centers_sorted[:, 0:3] #cut last col
    
    #to manually sort, as method we can put a list with the center indices in order
    elif type(method) == list and len(method) == len(centers):
        centers_sorted = centers[method]
    
    else:
        raise ValueError("Nonexistent method")
        
    return centers_sorted
In [9]:
def image_from_centers(data, centers_sorted):

    centers_sorted = np.asarray(centers_sorted).astype(int)
    n = len(centers_sorted)
        
    #create a new array for the palette strip; the background is white (255)
    mc_ratio = 9 #width of a color patch relative to the margin; large at colorpalette.cinema
    
    height = data.shape[0]
    width = data.shape[1]
    margin = width/(n*mc_ratio+n+1)
    
    palette = np.full((int(margin*(mc_ratio+2)), width, 3), 255, dtype='uint8')
    
    #calculate regions to fill with color, fill them with sorted centers
    for color in range(n):
        curr_col = centers_sorted[color]
        
        for h in range(int(margin), int(margin*(mc_ratio+1)), 1):
            for w in range(int((color+1)*margin*(mc_ratio+1) - margin*mc_ratio), int((color+1)*margin*(mc_ratio+1)), 1):
                palette[h][w] = curr_col
    
    #attach to data array
    full_arr = np.append(data, palette, axis=0)
    
    image_new = Image.fromarray(full_arr)
    display(image_new)

Master functions

One controlling function shows the intermediate plots before the palette; the other shows only the palette.

In [10]:
def palette_plots(filename, n, clustering, method=""):
    """Master function"""
    
    data, data2d = image_to_numpy(filename)
    print(f"Number of pixels in original image: {data2d.shape[0]}")
    data_red, data2d_red, scaling_param = reduce_image_res(filename)
    print(f"Take every {scaling_param}-th pixel")
    print(f"Number of data points in reduced image: {data2d_red.shape[0]}")
    
    if clustering == "k-means":
        centers, labels = n_kmeans(data2d_red, n)
    elif clustering == "agglomerative-ward":
        centers, labels = agglomerative(data2d_red, n, 'ward')
    elif clustering == "agglomerative-complete":
        centers, labels = agglomerative(data2d_red, n, 'complete')
    elif clustering == "agglomerative-single":
        centers, labels = agglomerative(data2d_red, n, 'single')
        
    scatterplot_rgb(data2d_red, labels)
    
    print("Cluster centers:")
    scatterplot_rgb(centers)
    
    if method != "":
        centers = sort_colors(centers, method=method)
    
    image_from_centers(data, centers)
    
    return data, centers
In [11]:
def palette_only(filename, n, clustering, method=""):
    """Master function"""
    
    data, data2d = image_to_numpy(filename)
    data_red, data2d_red, scaling_param = reduce_image_res(filename)
    
    if clustering == "k-means":
        centers, labels = n_kmeans(data2d_red, n)
    elif clustering == "agglomerative-ward":
        centers, labels = agglomerative(data2d_red, n, 'ward')
    elif clustering == "agglomerative-complete":
        centers, labels = agglomerative(data2d_red, n, 'complete')
    elif clustering == "agglomerative-single":
        centers, labels = agglomerative(data2d_red, n, 'single')
    
    if method != "":
        centers = sort_colors(centers, method=method)
    
    image_from_centers(data, centers)
    
    return data, centers

Example Images/Comparisons

Blade Runner

Here I show clustering with all four methods I implemented and then compare with c.c's version. K-means, despite being a simple clustering method, works quite well, which is why it's so reliable. Among the agglomerative variants, complete linkage performs best, which later examples will confirm. Single linkage just doesn't seem to work for fairly homogeneous data like this.
C.c has a very vibrant pink, which agglomerative-complete comes close to, and a light green, which interestingly is not captured by either method. Maybe the reduced resolution of my images plays a role here, or c.c selects colors manually after all.
K-means clustering also finishes about 10x faster than agglomerative clustering.

In [12]:
%%time

data, centers = palette_plots('BladeRunner-Neon-1024x544.jpg', 10, "k-means", method="blue")
Number of pixels in original image: 139264
Take every 5-th pixel
Number of data points in reduced image: 22345
Cluster centers:
Wall time: 3.05 s
In [13]:
%%time

data, centers = palette_plots('BladeRunner-Neon-1024x544.jpg', 10, "agglomerative-complete", method="blue")
Number of pixels in original image: 139264
Take every 5-th pixel
Number of data points in reduced image: 22345
Cluster centers:
Wall time: 24.9 s
In [14]:
data, centers = palette_plots('BladeRunner-Neon-1024x544.jpg', 10, "agglomerative-ward", method="blue")
Number of pixels in original image: 139264
Take every 5-th pixel
Number of data points in reduced image: 22345
Cluster centers:
In [15]:
data, centers = palette_plots('BladeRunner-Neon-1024x544.jpg', 10, "agglomerative-single", method="blue")
Number of pixels in original image: 139264
Take every 5-th pixel
Number of data points in reduced image: 22345
Cluster centers:
In [16]:
#c.c ver
image = Image.open("274535651_1569664613416522_6747045973204349058_n.jpg")
display(image)

2001: A Space Odyssey

This image was very important for me to get right, as the blue light and the blue eyes are the pivotal components of the image. I was very happy to see that, in my opinion, agglomerative clustering outperforms c.c's workflow here.

In [17]:
%%time

data, centers = palette_plots('space_odyssey.jpg', 10, "agglomerative-complete", method="red")
Number of pixels in original image: 262448
Take every 7-th pixel
Number of data points in reduced image: 21600
Cluster centers:
Wall time: 13 s
In [18]:
centers_sorted = sort_colors(centers, method=[0, 1, 2, 4, 5, 9, 8, 7, 6, 3])
image_from_centers(data, centers_sorted)
In [19]:
#compare to c.c
image = Image.open("273002620_1556192251430425_3005684082829387586_n.jpg")
display(image)

Closer (2004)

In [20]:
data, centers = palette_plots('natalie_portman.jpg', 10, "k-means", method="dark_to_light")
Number of pixels in original image: 303840
Take every 3-th pixel
Number of data points in reduced image: 33840
Cluster centers:
In [21]:
data, centers = palette_only('natalie_portman.jpg', 10, "agglomerative-complete", method="dark_to_light")

Elden Ring

In [22]:
data, centers = palette_plots('elden.png', 10, "k-means", method="dark_to_light")
Number of pixels in original image: 230400
Take every 10-th pixel
Number of data points in reduced image: 20736
Cluster centers:
In [23]:
data, centers = palette_only('elden.png', 10, "agglomerative-complete", method="dark_to_light")

#some vibrant orange would be great to pick, but is missed here too
In [24]:
#sort palette to be nicer

centers = sort_colors(centers, method=[0, 1, 2, 4, 7, 9, 8, 6, 5, 3])
image_from_centers(data, centers)

The following is my own screenshot from Elden Ring. Interestingly, agglomerative clustering doesn't find enough distinct clusters, and some cluster centroids are colors that don't really appear in the image. K-means, or a lower cluster count, works better.

In [25]:
data, centers = palette_plots('EldenRing_scr1.png', 10, "agglomerative-complete", method="dark_to_light")
Number of pixels in original image: 230400
Take every 10-th pixel
Number of data points in reduced image: 20736
Cluster centers:
In [26]:
data, centers = palette_only('EldenRing_scr1.png', 6, "agglomerative-complete", method="dark_to_light")
In [27]:
data, centers = palette_only('EldenRing_scr1.png', 10, "k-means", method="dark_to_light")

City of God

In [28]:
data, centers = palette_plots('city_of_god.jpg', 10, "k-means", method="dark_to_light")
Number of pixels in original image: 415200
Take every 9-th pixel
Number of data points in reduced image: 20648
Cluster centers:
In [29]:
data, centers = palette_only('city_of_god.jpg', 10, "agglomerative-complete", method="dark_to_light")

A very striking example is the next artwork by Yuan Yuan. A huge variety of blues is present almost homogeneously along a spectrum, and clustering just breaks the flow up into equal parts. Besides the blues, a patch of dispersed warm colors is also present, which k-means groups into a single cluster, while agglomerative clustering finds multiple reds/oranges.

In [30]:
data, centers = palette_plots('yuan-yuan-bluesity0318.jpg', 10, "k-means", method="dark_to_light")
Number of pixels in original image: 355200
Take every 8-th pixel
Number of data points in reduced image: 22320
Cluster centers:
In [31]:
data, centers = palette_plots('yuan-yuan-bluesity0318.jpg', 10, "agglomerative-complete")
Number of pixels in original image: 355200
Take every 8-th pixel
Number of data points in reduced image: 22320
Cluster centers:
In [32]:
#sorted
centers_sorted = sort_colors(centers, method=[8, 1, 6, 0, 3, 5, 4, 2, 9, 7])
image_from_centers(data, centers_sorted)

Belle (2021)

Belle is also a beautiful movie, and the main character's duality is represented by a duality in color, which the palette captures well.

In [33]:
data, centers = palette_plots('BELLEE.jpg', 10, "agglomerative-complete")
Number of pixels in original image: 230400
Take every 10-th pixel
Number of data points in reduced image: 20736
Cluster centers:
In [34]:
#nice palette

centers_sorted = sort_colors(centers, method=[8, 5, 0, 6, 3, 7, 1, 9, 2, 4])
image_from_centers(data, centers_sorted)

Dune (2021)

In [35]:
data, centers = palette_plots('dune_zendaya.jpg', 10, "k-means", method="dark_to_light")
Number of pixels in original image: 385920
Take every 17-th pixel
Number of data points in reduced image: 21470
Cluster centers:
In [36]:
data, centers = palette_plots('dune_zendaya.jpg', 10, "agglomerative-complete", method="dark_to_light")
Number of pixels in original image: 385920
Take every 17-th pixel
Number of data points in reduced image: 21470
Cluster centers:
In [37]:
data, centers = palette_plots('dune_rebecca.jpg', 10, "agglomerative-complete", method="dark_to_light")
Number of pixels in original image: 385920
Take every 17-th pixel
Number of data points in reduced image: 21470
Cluster centers:

The Dark Knight

Agglomerative clustering neatly spots the vibrant blue/teal of the police car's signal, which k-means misses.

In [43]:
data, centers = palette_plots("dark_knight_joker.webp", 10, "k-means")
Number of pixels in original image: 125195
Take every 5-th pixel
Number of data points in reduced image: 20090
Cluster centers:
In [44]:
centers_sorted = sort_colors(centers, method=[0, 5, 4, 9, 2, 3, 1, 7, 6, 8])
image_from_centers(data, centers_sorted)
In [40]:
data, centers = palette_plots("dark_knight_joker.webp", 10, "agglomerative-complete")
Number of pixels in original image: 125195
Take every 5-th pixel
Number of data points in reduced image: 20090
Cluster centers:
In [41]:
#sort it 
centers = sort_colors(centers, method=[1, 9, 5, 3, 0, 2, 7, 8, 4, 6])
image_from_centers(data, centers)